/* -*- C -*- */
/*
 * Copyright (c) 2007 Massachusetts Institute of Technology
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 */

#include <stdlib.h>
#include <spu_intrinsics.h>
#include <spu_mfcio.h>
#include "fftw-spu.h"
#include "../fftw-cell.h"

/* upper bound, in bytes, on the size of any single request issued by
   spu_dma1d_nowait; presumably the MFC per-transfer limit (16 KB) --
   confirm against the hardware documentation */
#define MAX_DMA_SIZE 16384
/* invoked as DMA_WAIT() once transfers have been queued; expands to a
   call of mfc_read_tag_status_all.  NOTE(review): the code visible in
   this file never sets a tag mask -- presumably that happens elsewhere;
   confirm */
#define DMA_WAIT mfc_read_tag_status_all

/* in-place transpose, n x n square complex matrix

   A points to the matrix; element (i, j) is addressed as
   A + i * lda + 2 * j, so lda is the row stride in units of R and each
   complex entry occupies two R's.

   Entries are exchanged pairwise: for every i the inner loops start at
   j = i and swap the piece at (i, j) with the piece at (j, i).  On the
   diagonal (j == i) both pointers refer to the same memory; every
   iteration performs all of its loads before its first store, so this
   aliasing is harmless.

   V, VL, LD, ST and STN2 are not defined in this file (presumably they
   come from fftw-spu.h); remarks below about what they do are inferred
   from the addressing only and should be confirmed there. */
static void complex_transpose(R *A, int lda, int n)
{
     int i, j;

     if (VL == 2) { /* single precision */
	  /* i and j advance in steps of 2: the matrix is processed as
	     2x2 tiles of complex numbers.  Rows i + 1 and j + 1 are
	     accessed unconditionally, so n is presumably required to be
	     even -- TODO confirm with callers */
	  for (i = 0; i < n; i += 2) {
	       R *Ai = A + i * lda;
	       /* unrolled: two tiles per iteration, (i, j) and (i, j + 2)
		  trade places with (j, i) and (j + 2, i); the bound
		  j < n - 2 keeps index j + 2 inside the matrix */
	       for (j = i; j < n - 2; j += 4) {
		    R *Aj = A + j * lda;
		    R *Aij = Ai + 2 * j;
		    R *Aji = Aj + 2 * i;
		    V aj0i0 = LD(Aji + 0 * lda, 0, 0);
		    V aj1i0 = LD(Aji + 1 * lda, 0, 0);
		    V aj2i0 = LD(Aji + 2 * lda, 0, 0);
		    V aj3i0 = LD(Aji + 3 * lda, 0, 0);
		    V ai0j0 = LD(Aij + 0 * lda + 0, 0, 0);
		    V ai1j0 = LD(Aij + 1 * lda + 0, 0, 0);
		    V ai0j2 = LD(Aij + 0 * lda + 4, 0, 0);
		    V ai1j2 = LD(Aij + 1 * lda + 4, 0, 0);
		    /* STN2(p, a, b, lda) presumably writes the 2x2
		       transpose of the tile held in a, b to the rows at
		       p and p + lda -- confirm in fftw-spu.h */
		    STN2(Aij + 0 * lda + 0, aj0i0, aj1i0, lda);
		    STN2(Aij + 0 * lda + 4, aj2i0, aj3i0, lda);
		    STN2(Aji + 0 * lda, ai0j0, ai1j0, lda);
		    STN2(Aji + 2 * lda, ai0j2, ai1j2, lda);
	       }
	       /* leftover tiles, one per iteration */
	       for (; j < n; j += 2) {
		    R *Aj = A + j * lda;
		    R *Aij = Ai + 2 * j;
		    R *Aji = Aj + 2 * i;
		    V aj0i0 = LD(Aji + 0 * lda, 0, 0);
		    V aj1i0 = LD(Aji + 1 * lda, 0, 0);
		    V ai0j0 = LD(Aij + 0 * lda, 0, 0);
		    V ai1j0 = LD(Aij + 1 * lda, 0, 0);
		    STN2(Aij + 0 * lda, aj0i0, aj1i0, lda);
		    STN2(Aji + 0 * lda, ai0j0, ai1j0, lda);
	       }
	  }

     } else { /* double precision */
	  /* one row i at a time; offsets along the row advance by 2 R's
	     per LD/ST, i.e. apparently one complex number per vector */
	  for (i = 0; i < n; ++i) {
	       R *Ai = A + i * lda;
	       /* unrolled: entries (i, j..j+3) trade places with
		  entries (j..j+3, i) */
	       for (j = i; j < n - 3; j += 4) {
		    R *Aj = A + j * lda;
		    R *Aij = Ai + 2 * j;
		    R *Aji = Aj + 2 * i;
		    V ai0j0 = LD(Aij + (0 * lda), 0, 0);
		    V ai0j1 = LD(Aij + (0 * lda + 2), 0, 0);
		    V ai0j2 = LD(Aij + (0 * lda + 4), 0, 0);
		    V ai0j3 = LD(Aij + (0 * lda + 6), 0, 0);
		    V aj0i0 = LD(Aji + (0 * lda), 0, 0);
		    V aj1i0 = LD(Aji + (1 * lda), 0, 0);
		    V aj2i0 = LD(Aji + (2 * lda), 0, 0);
		    V aj3i0 = LD(Aji + (3 * lda), 0, 0);
		    ST(Aji + (0 * lda), ai0j0, 0, 0);
		    ST(Aji + (1 * lda), ai0j1, 0, 0);
		    ST(Aji + (2 * lda), ai0j2, 0, 0);
		    ST(Aji + (3 * lda), ai0j3, 0, 0);
		    ST(Aij + (0 * lda), aj0i0, 0, 0);
		    ST(Aij + (0 * lda + 2), aj1i0, 0, 0);
		    ST(Aij + (0 * lda + 4), aj2i0, 0, 0);
		    ST(Aij + (0 * lda + 6), aj3i0, 0, 0);
	       }
	       /* leftover entries, one swap per iteration */
	       for (; j < n; ++j) {
		    R *Aj = A + j * lda;
		    R *Aij = Ai + 2 * j;
		    R *Aji = Aj + 2 * i;
		    V ai0j0 = LD(Aij + (0 * lda), 0, 0);
		    V aj0i0 = LD(Aji + (0 * lda), 0, 0);
		    ST(Aji + (0 * lda), ai0j0, 0, 0);
		    ST(Aij + (0 * lda), aj0i0, 0, 0);
	       }
	  }
     }
}


/* Queue, without waiting for completion, the DMA requests that move
   sz bytes between spu_addr and the 64-bit address ppu_addr.  cmd is
   handed to spu_mfcdma64 unchanged and every request uses tag 0.

   The transfer is split into pieces: when ppu_addr is not a multiple
   of ALIGNMENT the first piece only runs up to the next ALIGNMENT
   boundary; after that each piece is as large as possible, capped at
   MAX_DMA_SIZE.  The mask arithmetic assumes ALIGNMENT is a power of
   two (its value is not defined in this file).

   sz == 0 issues no request at all. */
static void spu_dma1d_nowait(void *spu_addr, 
			     unsigned long long ppu_addr, 
			     size_t sz, unsigned int cmd)
{
     unsigned int tag_id = 0;

     /* byte-typed cursor: arithmetic directly on a void * is a GNU
	extension, not ISO C */
     char *ls = spu_addr;

     while (sz > 0) {
	  /* select chunk to align ppu_addr */
	  size_t chunk = ALIGNMENT - (ppu_addr & (ALIGNMENT - 1));

	  /* if already aligned, transfer the whole thing */
	  if (chunk == ALIGNMENT || chunk > sz) 
	       chunk = sz;

	  /* ...up to MAX_DMA_SIZE */
	  if (chunk > MAX_DMA_SIZE) 
	       chunk = MAX_DMA_SIZE;

	  spu_mfcdma64(ls, mfc_ea2h(ppu_addr), mfc_ea2l(ppu_addr),
		       chunk, tag_id, cmd);

	  /* advance both sides past the piece just queued */
	  sz -= chunk;
	  ppu_addr += chunk;
	  ls += chunk;
     }
}

/* Queue DMAs for v rows of n contiguous complex numbers each, without
   waiting for them to complete.  Successive rows are spu_vstride R's
   apart starting at A, and ppu_vstride_bytes bytes apart starting at
   ppu_addr.  cmd is forwarded to spu_dma1d_nowait for every row. */
static void dma2d_contig_nowait(R *A, unsigned long long ppu_addr,
				int n, 
				int v, int spu_vstride, int ppu_vstride_bytes,
				unsigned int cmd)
{
     /* one row = n complex numbers = 2 * n reals */
     const size_t row_bytes = 2 * sizeof(R) * n;
     R *row = A;
     int remaining;

     for (remaining = v; remaining > 0; --remaining) {
	  spu_dma1d_nowait(row, ppu_addr, row_bytes, cmd);
	  row += spu_vstride;
	  ppu_addr += ppu_vstride_bytes;
     }
}

/* Blocking 1D transfer: queue the DMA requests for sz bytes between
   spu_addr and ppu_addr using command cmd (see spu_dma1d_nowait), then
   invoke DMA_WAIT() before returning. */
void X(spu_dma1d)(void *spu_addr, unsigned long long ppu_addr, size_t sz,
		  unsigned int cmd)
{
     spu_dma1d_nowait(spu_addr, ppu_addr, sz, cmd);
     DMA_WAIT();
}

/* selects what a dma2d_transposed() pass does to each square tile */
enum operation {
     DO_DMA,		/* queue the DMA requests for the tile */
     DO_TRANSPOSE	/* run complex_transpose on the tile */
};

/* Algorithm for transposing rectangular matrices from PPU memory to
   SPU local store and vice versa.

   The rectangle (r by c complex numbers, row stride lda R's at A) is
   consumed as a sequence of squares.  Each step cuts a maximal square
   of side min(r, c) at the current position; the square is copied
   with contiguous DMA and transposed in place in local store, and the
   loop then continues on the rest of the matrix.

   A single call performs only one of those two actions on every
   square, selected by op: DO_DMA queues the DMAs (it does not wait
   for them), anything else runs complex_transpose.  For reasons not
   well understood, performance seems better if we execute all the
   DMAs first and then all the transpositions.  This needs further
   investigation.
*/
static void dma2d_transposed(R *A, int lda, 
			     unsigned long long ppu_addr, int ppu_stride_bytes,
			     int r, int c,
			     unsigned int cmd, 
			     enum operation op)
{
     while (c > 0 && r > 0) {
	  /* r > c leaves a CxC square to cut, otherwise an RxR one */
	  const int tall = (r > c);
	  const int side = tall ? c : r;

	  if (op == DO_DMA)
	       dma2d_contig_nowait(A, ppu_addr, side, side,
				   lda, ppu_stride_bytes, cmd);
	  else
	       complex_transpose(A, lda, side);

	  if (tall) {
	       /* skip c rows in local store, c complex numbers on the
		  PPU side */
	       r -= c;
	       A += lda * c;
	       ppu_addr += 2 * sizeof(R) * c;
	  } else {
	       /* skip r complex numbers in local store, r strides on
		  the PPU side */
	       c -= r;
	       A += 2 * r;
	       ppu_addr += ppu_stride_bytes * r;
	  }
     }
}

/* 2D dma transfer routine, works for 
   ppu_stride_bytes == 2 * sizeof(R) or ppu_vstride_bytes == 2 * sizeof(R).

   A holds v rows of n complex numbers (stride 2 between complex
   numbers and 2 * n between rows, in units of R).  The call returns
   only after DMA_WAIT().

   The second condition is not tested: any ppu_stride_bytes other than
   2 * sizeof(R) takes the transposed path.  On that path an
   MFC_PUT_CMD transposes A in place before the transfer and does not
   restore it; an MFC_GET_CMD transposes A in place after the transfer
   has completed. */
void X(spu_dma2d)(R *A, unsigned long long ppu_addr, 
		  int n, /* int spu_stride = 2 , */ int ppu_stride_bytes,
		  int v, /* int spu_vstride = 2 * n, */
		  int ppu_vstride_bytes,
		  unsigned int cmd)
{
     if (ppu_stride_bytes != 2 * sizeof(R)) {
	  /* ppu_vstride_bytes == 2 * sizeof(R), use transposed DMA */
	  const int lda = 2 * n;

	  /* outgoing data must be transposed before it is sent */
	  if (cmd == MFC_PUT_CMD)
	       dma2d_transposed(A, lda, ppu_addr, ppu_stride_bytes,
				v, n, cmd, DO_TRANSPOSE);

	  dma2d_transposed(A, lda, ppu_addr, ppu_stride_bytes,
			   v, n, cmd, DO_DMA);
	  DMA_WAIT();

	  /* incoming data is transposed once it has arrived */
	  if (cmd == MFC_GET_CMD)
	       dma2d_transposed(A, lda, ppu_addr, ppu_stride_bytes,
				v, n, cmd, DO_TRANSPOSE);
	  return;
     }

     /* contiguous array on the PPU side */

     /* if the input is a 1D contiguous array, collapse n into v */
     if (ppu_vstride_bytes == ppu_stride_bytes * n) {
	  n *= v;
	  v = 1;
     }

     dma2d_contig_nowait(A, ppu_addr, n, v,
			 2 * n, ppu_vstride_bytes, cmd);
     DMA_WAIT();
}
